library(tidyverse)
library(maptools)
library(igraph)
library(parallel)
library(redist)
library(spdep)
library(RColorBrewer)
library(gridExtra)
gpclibPermit()
library(sf)
library(ggrepel)
library(patchwork)

## Inverse-probability reweighting of tempered sampler output.
##
## Keeps the draws of `x` whose `distance_parity` is at most `pop` and whose
## `beta_sequence` equals `beta`, then resamples them with replacement using
## weights 1 / exp(beta * constraint_pop).
##
## Arguments:
##   x    - list-like sampler output with per-draw elements `distance_parity`,
##          `beta_sequence` and `constraint_pop`, plus `partitions`, a matrix
##          with one column per draw.
##   beta - constraint strength to select on (compared with `==`).
##   pop  - largest distance from parity a draw may have to be kept.
##
## Returns the resampled columns of `x$partitions`, always as a matrix (zero
## columns, with a warning, when no draw passes the filter).
ipw <- function(x, beta, pop){
    indpop <- which(x$distance_parity <= pop)
    indbeta <- which(x$beta_sequence == beta)
    inds <- intersect(indpop, indbeta)
    if (length(inds) == 0L) {
        warning("ipw(): no draws satisfy the parity/beta filter; ",
                "returning zero partitions", call. = FALSE)
        return(x$partitions[, integer(0), drop = FALSE])
    }
    psi <- x$constraint_pop[inds]
    w <- 1 / exp(beta * psi)
    ## Sample positions rather than values: sample(inds, ...) would treat a
    ## single surviving index k as "sample from 1:k" and then fail on the
    ## length of `prob`. For two or more draws this consumes the RNG stream
    ## exactly as sample(inds, length(inds), replace = TRUE, prob = w) does.
    picked <- inds[sample.int(length(inds), replace = TRUE, prob = w)]
    ## drop = FALSE keeps a matrix even when only one column is selected.
    x$partitions[, picked, drop = FALSE]
}
## Shared ggplot2 theme for the figures in this script: theme_classic() with
## blank panel background, grid lines and strip background, black axis lines,
## a black panel border, an untitled legend at the bottom, and centred title
## and subtitle.
ben_theme <- function(){
    overrides <- theme(
        panel.background = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        panel.border = element_rect(colour = "black", fill = NA, size = 1),
        strip.background = element_blank(),
        legend.position = "bottom",
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5),
        plot.subtitle = element_text(hjust = .5)
    )
    theme_classic() + overrides
}

## Load Iowa map
ia <- readShapePoly("../data/ia/county.shp")
votes <- read_csv("../data/ia/ia_vote.csv")
pops <- read_csv("../data/ia/ia_pop.csv") %>%
    mutate(County = gsub(" County", "", County))
cds <- read_csv("../data/ia/ia_cds.csv")
## votes and pops are joined on whatever columns they share (natural join);
## the result is trimmed to the columns used below and then matched to the
## congressional-district table by county name.
info <- inner_join(votes, pops) %>%
    rename(county = County,
           trump = `Trump #`,
           clinton = `Clinton #`,
           other = `Others #`,
           total = `Total`,
           fips = `FIPS code[10]`,
           pop = Population) %>%
    select(county, trump, clinton, fips, pop) %>%
    inner_join(cds, by = c("county" = "COUNTY"))

## Join with shapefile
ia@data$COUNTY <- as.character(ia@data$COUNTY)
## Harmonise the one county name changed here so it matches `info`.
ia@data$COUNTY[ia@data$COUNTY == "Obrien"] <- "O'Brien"
ia@data$id <- seq_len(nrow(ia@data))
ia_attrs <- inner_join(ia@data, info, by = c("COUNTY" = "county"))
## The attribute table must stay row-for-row aligned with the polygons. An
## inner join silently drops unmatched counties and duplicates multiply
## matched ones, so verify the county sequence is unchanged before
## overwriting the data slot.
if (!identical(ia_attrs$COUNTY, ia@data$COUNTY)) {
    stop("County attributes no longer line up with the shapefile after the ",
         "join (", nrow(ia_attrs), " joined rows vs ", nrow(ia@data),
         " polygons); check county names in the vote/population/CD tables.",
         call. = FALSE)
}
ia@data <- ia_attrs

## --------------
## Load solutions
## --------------
## Read the 50 dissimilarity_<i>.csv files in parallel (40 workers) and stack
## them row-wise into a single table.
sols <- do.call(
    "rbind",
    mclapply(
        paste0("/n/imai_lab/bfifield/enumeration/ia_enumerate/dissimilarity_", 1:50, ".csv"),
        function(path) read_csv(path),
        mc.cores = 40
    )
)

## ----------
## Run redist
## ----------
## Inputs shared by every chain. `rep` is a data vector here; calls such as
## rep("a", 2) elsewhere in the script still reach the base function because
## R skips non-function bindings when resolving a function call.
pop <- ia@data$pop
rep <- ia@data$trump
nsims <- 250000
nchains <- 8

## Unconstrained
## Each chain samples plans, computes the segregation index per plan, and the
## per-chain results are flattened into one vector.
seg_unc <- unlist(
    mclapply(seq_len(nchains), function(chain){
        draws <- redist.mcmc(ia, popvec = pop,
                             nsims = nsims, ndists = 4)
        seg <- redist.segcalc(draws, grouppop = rep, fullpop = pop)
        coda::mcmc(seg)
    }, mc.cores = detectCores())
)

## 5% parity
## Sample with a population constraint (beta = -25), then reweight with ipw()
## keeping only plans within 5% of parity.
seg_05 <- unlist(
    mclapply(seq_len(nchains), function(chain){
        draws <- redist.mcmc(ia, popvec = pop,
                             nsims = nsims, ndists = 4,
                             constraint = "population", beta = -25)
        kept <- ipw(draws, beta = -25, pop = .05)
        seg <- redist.segcalc(kept, grouppop = rep, fullpop = pop)
        coda::mcmc(seg)
    }, mc.cores = detectCores())
)

## 1% parity
## Same as above with a stronger constraint (beta = -50) and a 1% cutoff.
seg_01 <- unlist(
    mclapply(seq_len(nchains), function(chain){
        draws <- redist.mcmc(ia, popvec = pop,
                             nsims = nsims, ndists = 4,
                             constraint = "population", beta = -50)
        kept <- ipw(draws, beta = -50, pop = .01)
        seg <- redist.segcalc(kept, grouppop = rep, fullpop = pop)
        coda::mcmc(seg)
    }, mc.cores = detectCores())
)

## RSG
## `pattern` is a regular expression, not a shell glob: "rsg_*" means "rsg"
## followed by zero or more underscores and therefore matched every file whose
## name merely contains "rsg". Anchor it to names that start with "rsg_".
fn <- list.files("../data/ia_enumerate/", pattern = "^rsg_")
if (length(fn) == 0L) {
    stop("No rsg_* files found in ../data/ia_enumerate/", call. = FALSE)
}
## Read every matching file in parallel and stack the results row-wise.
rsg_sims <- do.call("rbind", mclapply(paste0("../data/ia_enumerate/", fn), read_csv, mc.cores = 20))

## Create dataframe for plot
## One row per segregation value, labelled with the parity constraint it was
## produced under and the algorithm (enumerated truth, MCMC, or RSG). The
## intermediate vectors live inside local() so only `df_plot` is created.
df_plot <- local({
    lab_none <- "No Equal Population Constraint"
    lab_05 <- "5% Equal Population Constraint"
    lab_01 <- "1% Equal Population Constraint"
    parity_labels <- c(lab_none, lab_05, lab_01)

    ## Enumerated solutions: all, within 5%, within 1%
    truth_05 <- sols$par <= .05
    truth_01 <- sols$par <= .01
    n_truth <- c(nrow(sols), nrow(sols[truth_05, ]), nrow(sols[truth_01, ]))

    ## MCMC draws per constraint
    n_mcmc <- c(length(seg_unc), length(seg_05), length(seg_01))

    ## RSG draws per constraint
    rsg_none <- rsg_sims$parity == lab_none
    rsg_05 <- rsg_sims$parity == lab_05
    rsg_01 <- rsg_sims$parity == lab_01
    n_rsg <- c(sum(rsg_none), sum(rsg_05), sum(rsg_01))

    data.frame(
        seg = c(sols$seg, sols$seg[truth_05], sols$seg[truth_01],
                seg_unc, seg_05, seg_01,
                rsg_sims$seg[rsg_none],
                rsg_sims$seg[rsg_05],
                rsg_sims$seg[rsg_01]),
        parity = rep(rep(parity_labels, times = 3),
                     times = c(n_truth, n_mcmc, n_rsg)),
        alg = rep(c("Truth", "MCMC", "RSG"),
                  times = c(sum(n_truth), sum(n_mcmc), nrow(rsg_sims)))
    ) %>%
        mutate(alg = factor(alg, levels = c("Truth", "MCMC", "RSG")),
               parity = factor(parity, levels = parity_labels))
})

## write_csv(df_plot[df_plot$alg != "Truth",], path = "../data/ia_enumerate/mcmc_out.csv")
## mcmc_sims <- read_csv("../data/ia_enumerate/mcmc_out.csv")
## truth_sims <- data.frame(seg = c(sols$seg, sols$seg[sols$par <= .05], sols$seg[sols$par <= .01]),
##                          parity = c(rep("No Equal Population Constraint", nrow(sols)),
##                                     rep("5% Equal Population Constraint", nrow(sols[sols$par <= .05,])),
##                                     rep("1% Equal Population Constraint", nrow(sols[sols$par <= .01,]))),
##                          alg = "Truth")
## rm(sols)
## df_plot <- bind_rows(mcmc_sims, truth_sims) %>%
##     mutate(alg = factor(alg, levels = c("Truth", "MCMC", "RSG")),
##            parity = factor(parity, levels = c("No Equal Population Constraint",
##                                               "5% Equal Population Constraint",
##                                               "1% Equal Population Constraint")))
## rm(mcmc_sims, truth_sims)

## Validation figure: density of the dissimilarity index for each algorithm,
## one panel per parity constraint. "Truth" is drawn as a filled grey density;
## MCMC (solid black) and RSG (dashed red) are drawn as outlines only
## (alpha 0).
##
## The plot is built first and then passed to ggsave() explicitly. Chaining
## `+ ggsave(...)` onto the plot only worked through last_plot() side effects
## and is rejected as an error by recent ggplot2 releases.
validation_plot <- ggplot(df_plot, aes(seg, colour = alg, fill = alg,
                                       alpha = alg, lty = alg)) +
    geom_density(lwd = 1.1) +
    facet_grid(~ parity) +
    theme_classic() +
    scale_colour_manual(values = c("grey", "black", "red")) +
    scale_fill_manual(values = c("grey", "white", "white")) +
    scale_linetype_manual(values = c(1, 1, 2)) +
    scale_alpha_manual(values = c(1, 0, 0)) +
    labs(x = "Republican Dissimilarity Index", y = "Density") +
    ben_theme() +
    ## Overrides ben_theme()'s bottom legend: place it inside the panel area.
    theme(legend.position = c(.1, .75))
    ## + guides(colour = guide_legend(nrow = 1, byrow = TRUE))
ggsave("../paper/figs/ia_enumerate_validation.pdf", plot = validation_plot,
       height = 3, width = 8)

## ----------
## Create map
## ----------
## Marker positions and labels for the two cities shown on the map
## (coordinates presumably in the shapefile's projected units -- TODO confirm).
cities <- data.frame(x = c(452378.2, 615964.6),
                     y = c(4613623, 4659540),
                     label = c("Des Moines", "Cedar Rapids"))

## Re-key the attribute table by its row names, flatten the polygons to a
## point table grouped by that key, and attach the county attributes.
ia@data$id <- row.names(ia@data)
ia_points <- fortify(ia, region = "id")
ia_df <- plyr::join(x = ia_points, y = ia@data, by = "id")

# map_plot <- ggplot(data = ia_df) + 
#     aes(long, lat, group = group, alpha = pop/1000, fill = as.factor(DISTRICT)) + 
#     geom_polygon(colour = "black", size = .3) +
#     geom_path(color="black", lwd = .05) + 
#     theme_classic() +
#     scale_fill_manual(values = brewer.pal(4, "Dark2")) +
#     coord_equal() +
#     scale_alpha_continuous(labels = scales::comma) + 
#     guides(fill = guide_legend(title = "Iowa CDs", label.position = "bottom")) +
#     guides(alpha = guide_legend(title = "Population\n(Thousands)", label.position = "bottom")) +
#     labs(x = "", y = "",
#          title = "Congressional Districts of Iowa (2016)") +
#     annotate("point", x = cities$x, y = cities$y, size = 2) + 
#     annotate("text", x = cities$x, y = cities$y + 15000,
#              label = cities$label, size = 6) +
#     theme(panel.border = element_rect(colour = "black", fill=NA),
#           plot.title = element_text(hjust = 0.5),
#           legend.position = "none",
#           legend.box = "vertical",
#           legend.title = element_text(size = 12), 
#           legend.text = element_text(size = 12),
#           axis.ticks = element_blank(),
#           axis.text = element_blank())
# 
# ## Create legend for map
# df_legend <- expand.grid(district = 1:4,
#                          population = seq(100, 400, by = 100))
# coords <- expand.grid(x = 0:3, y = 0:3)
# df_legend$x <- coords$x
# df_legend$y <- coords$y
# 
# legend <- ggplot(df_legend, aes(x, y, fill = as.factor(district),
#                                 alpha = population)) +
#     geom_raster() + coord_flip() +
#     scale_fill_manual(values = brewer.pal(4, "Dark2")) +
#     scale_x_continuous(breaks = c(0, 1, 2, 3),
#                        labels = c("District 1", "District 2",
#                                   "District 3", "District 4")) + 
#     scale_y_continuous(breaks = c(0, 1, 2, 3),
#                        label = c("100", "200", "300", "400")) + 
#     theme_minimal() +
#     labs(x = "", y = "Population (Thousands)") + 
#     theme(line = element_blank(),
#           legend.position = "none",
#           axis.text = element_text(size = 13),
#           axis.title = element_text(size = 13))

## Updated Map
## County polygons as an sf object, plus the two city markers as sf points.
ia_sf <- sf::st_as_sf(ia)
ia_cities <- tibble(label = cities$label)
geom <- st_sfc(st_point(c(cities$x[1],cities$y[1])), st_point(c(cities$x[2],cities$y[2])))
ia_cities <- cbind(ia_cities, geom)
## The plot-level aesthetic below is fill = pop/1000 and the city layers do
## not opt out of inheriting it, so give the city table a placeholder `pop`.
ia_cities$pop <- 0
## Dissolve counties into one geometry per congressional district.
ia_dists <- ia_sf %>% group_by(DISTRICT) %>% summarise(geometry = st_union(geometry))

## Counties shaded by population (thousands), district outlines drawn on top
## with thick unfilled borders, and the two cities marked and labelled in red.
map_plot <- ia_sf %>% ggplot(aes(fill = pop/1000)) +
    geom_sf() +
    theme_classic() +
    scale_fill_gradient(low = '#ffffff', high = '#0c0c0c') +
    theme(panel.border = element_rect(colour = "white", fill=NA),
          plot.title = element_text(hjust = 0.5),
          #legend.box = "vertical",
          legend.position = 'bottom',
          legend.title = element_text(size = 12),
          legend.text = element_text(size = 12),
          axis.ticks = element_blank(),
          axis.text = element_blank(),
          axis.line = element_blank()) +
    labs(x='',y='',title = 'Congressional Districts of Iowa (2016)') +
    ## The original call ended with a stray trailing comma, i.e. an empty
    ## argument passed to guide_legend(); it is removed here.
    guides(fill = guide_legend(title = "Population\n(Thousands)", label.position = "bottom")) +
    geom_sf(data = nngeo::st_remove_holes(ia_dists), fill = NA, lwd = 2) +
    geom_sf(data = ia_cities, fill=NA, aes(geometry=geometry), color = 'red') +
    ggsflabel::geom_sf_text_repel(data = ia_cities, fill=NA, aes(geometry=geometry, label = label), color = 'red')# +
    #ggsave(filename = "../paper/figs/map_ia.png")

## Parity plot
## Histogram of distance from population parity for the enumerated plans with
## par <= 0.2, in 0.01-wide bins, expressed as a share of ALL enumerated plans.
##
## The bin edges are fixed at 0, 0.01, ..., 0.20 to cover the whole filtered
## range. Building them with seq(0, max(par), .01) stops at the last multiple
## of 0.01 below max(par), so plans in the top partial bin were assigned NA
## and dropped from the plot.
parity_plot_df <- sols %>% filter(par <= .2) %>%
    mutate(
        parity_group = cut(par, breaks = (0:20) / 100,
                           include.lowest = TRUE, labels = FALSE)
    ) %>%
    ## parity_group is the bin number, so pg is the bin's upper edge.
    group_by(pg = parity_group / 100) %>%
    summarize(n = n()) %>%
    ungroup() %>%
    mutate(prop = n / nrow(sols))
## Share of all enumerated plans within 1%, 5% and 10% of parity
## (a plain numeric vector despite the name).
numsols_df <- c(
    sum(sols$par <= .01)/nrow(sols),
    sum(sols$par <= .05)/nrow(sols),
    sum(sols$par <= .1)/nrow(sols)
)
## Single multi-line annotation listing those three shares as percentages.
text_out <- data.frame(text = paste0(paste0("< ", format(c(.01, .05, .10), nsmall = 2), " Parity: ", paste0(format(round(numsols_df, 7)*100, nsmall = 2), "%"), collapse = "\n")))

## Bars are centred on the bin midpoint (upper edge minus half a bin width).
parity_plot <- ggplot(parity_plot_df, aes(pg - .005, prop)) +
    geom_bar(stat = "identity") +
    theme_classic() +
    geom_text(data = text_out, aes(x = .13, y = Inf, label = text),
              vjust = 1.2, hjust = 1.1) +
    labs(x = "Distance from Population Parity",
         y = "Share of Sampled Plans",
         title = "Distribution of Population Parity on Iowa Map") +
    scale_y_continuous(labels = scales::percent_format(accuracy = .01)) +
    theme(panel.border = element_rect(colour = "black", fill=NA),
          plot.title = element_text(hjust = 0.5))
          ## previously also considered: legend.position = c(.05, .05)


## Combined figure: map on the left, parity histogram on the right.
##
## The earlier gridExtra layout stacked `map_plot` on top of a hand-built
## `legend` plot, but the code that created that plot is commented out, so the
## name `legend` no longer refers to a plot made by this script and
## arrangeGrob() cannot lay it out. That path wrote to the same file as the
## patchwork version below, so only the patchwork version is kept.
plot_out <- map_plot + parity_plot + plot_layout(widths = c(2.5, 2))
ggsave('map_descriptives_ia.pdf', plot_out,path = '../paper/figs/', height = 5, width = 10)
